{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import rcParams\n",
    "%matplotlib inline\n",
    "%config InlineBackend.figure_format = 'retina'\n",
    "rcParams['figure.figsize'] = (12,6)\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "\n",
    "import os\n",
    "\n",
    "import torch\n",
    "from torch import nn, optim\n",
    "from torchvision import models\n",
    "from nupic.torch.modules import KWinners2d, Flatten\n",
    "from torchsummary import summary\n",
    "\n",
    "from utils import Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# largest reachable k-th smallest value of a flattened 10x10 random tensor\n",
    "a = torch.randn(10,10)\n",
    "print(a.shape)\n",
    "flat = a.view(-1)\n",
    "# torch.kthvalue requires 1 <= k <= flat.numel(); the previous hard-coded k=101\n",
    "# exceeded the 100 available elements and raised RuntimeError, so clamp k.\n",
    "k = min(101, flat.numel())\n",
    "kth = torch.kthvalue(flat, k)\n",
    "kth"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Set up a simple dataset and model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Summarise the numpy namespace instead of dumping every attribute name:\n",
    "# the bare dir(np) listing produced hundreds of output lines and buried the rest of the notebook.\n",
    "public_names = [name for name in dir(np) if not name.startswith('_')]\n",
    "# show how many public names there are plus a short preview\n",
    "len(public_names), public_names[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "False"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# the mean of an empty sequence is NaN, so the negated NaN check evaluates to False\n",
    "empty_mean = np.mean([])\n",
    "not np.isnan(empty_mean)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 152,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load dataset\n",
    "dataset_config = {\n",
    "    \"dataset_name\": \"CIFAR10\",\n",
    "    \"data_dir\": \"~/nta/datasets\",\n",
    "    # three values each; presumably per-channel normalization stats -- confirm in utils.Dataset\n",
    "    \"stats_mean\": (0.4914, 0.4822, 0.4465),\n",
    "    \"stats_std\": (0.2023, 0.1994, 0.2010),\n",
    "}\n",
    "dataset = Dataset(dataset_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 159,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load model\n",
    "from networks import VGG19Small\n",
    "\n",
    "# configuration passed to the network constructor as a single dict\n",
    "model_config = {'device': 'cpu', 'kwinners': True}\n",
    "model = VGG19Small(model_config)\n",
    "loss_func = nn.CrossEntropyLoss()\n",
    "# summary(model, (3, 32, 32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 160,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([128, 3, 32, 32]), torch.Size([128]))"
      ]
     },
     "execution_count": 160,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# get one batch\n",
    "train_batches = iter(dataset.train_loader)\n",
    "# the names input/target are read by the forward-pass cell below\n",
    "input, target = next(train_batches)\n",
    "(input.shape, target.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# run one forward pass and loss calculation\n",
    "# Call the module itself rather than .forward(): nn.Module.__call__ also runs any\n",
    "# registered hooks, and this matches how the epoch loop below invokes the model.\n",
    "output = model(input)\n",
    "loss = loss_func(output, target)\n",
    "loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# how many entries model.correlations currently holds\n",
    "num_correlations = len(model.correlations)\n",
    "num_correlations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# print the shape of every entry in model.correlations, one per line\n",
    "correlation_shapes = [corr.shape for corr in model.correlations]\n",
    "for shape in correlation_shapes:\n",
    "    print(shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# side * side * channels for each (side, channels) pair\n",
    "side_and_channels = [(32, 64), (16, 64), (8, 64), (4, 128), (2, 256), (1, 512)]\n",
    "tuple(side * side * channels for side, channels in side_and_channels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# run a full training epoch\n",
    "optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)\n",
    "# Optimizer steps must iterate the training split; the previous loop stepped over\n",
    "# dataset.test_loader, which fits the model on the held-out data.\n",
    "for inputs, targets in dataset.train_loader:\n",
    "    optimizer.zero_grad()\n",
    "    outputs = model(inputs)\n",
    "    loss = loss_func(outputs, targets)\n",
    "    loss.backward()\n",
    "    optimizer.step()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Monitor duty cycle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# kwinners was never defined anywhere in this notebook (it only existed as leftover\n",
    "# kernel state), so a fresh Restart & Run All failed here with NameError.\n",
    "# Rebuild it from the model using the KWinners2d class imported in the first cell.\n",
    "# NOTE(review): assumes the layers of interest are exactly the KWinners2d submodules of model -- confirm.\n",
    "kwinners = [module for module in model.modules() if isinstance(module, KWinners2d)]\n",
    "\n",
    "# flattened duty cycle of the first k-winners layer\n",
    "first_kwinner = kwinners[0]\n",
    "first_kwinner.duty_cycle.view(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "metadata": {},
   "outputs": [],
   "source": [
    "# outer product of the flattened duty cycles of each pair of consecutive entries in kwinners\n",
    "duty_cycles = [layer.duty_cycle.view(-1) for layer in kwinners]\n",
    "correlations = [\n",
    "    torch.ger(prev_dc, next_dc)\n",
    "    for prev_dc, next_dc in zip(duty_cycles, duty_cycles[1:])\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "4"
      ]
     },
     "execution_count": 87,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# one correlation matrix per pair of consecutive entries in kwinners\n",
    "num_layer_pairs = len(correlations)\n",
    "num_layer_pairs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "5"
      ]
     },
     "execution_count": 88,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# number of entries in kwinners\n",
    "num_kwinners = len(kwinners)\n",
    "num_kwinners"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([64, 64])"
      ]
     },
     "execution_count": 90,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# shape of the first correlation matrix\n",
    "first_correlation = correlations[0]\n",
    "first_correlation.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 93,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([256, 512])"
      ]
     },
     "execution_count": 93,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# shape of the last correlation matrix\n",
    "last_correlation = correlations[-1]\n",
    "last_correlation.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 100,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1.0908, 1.0807, 1.0834, 1.0854, 1.0831],\n",
       "        [1.1194, 1.1060, 1.1095, 1.1122, 1.1092],\n",
       "        [1.1184, 1.1051, 1.1087, 1.1113, 1.1083],\n",
       "        [1.1146, 1.1018, 1.1052, 1.1078, 1.1048],\n",
       "        [1.1037, 1.0921, 1.0952, 1.0976, 1.0949]])"
      ]
     },
     "execution_count": 100,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# element-wise exponential of the top-left 5x5 corner of the first correlation matrix\n",
    "top_left = correlations[0][:5, :5]\n",
    "top_left.exp()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 101,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.0869, 0.0776, 0.0801, 0.0819, 0.0798],\n",
       "        [0.1128, 0.1007, 0.1039, 0.1064, 0.1036],\n",
       "        [0.1119, 0.1000, 0.1031, 0.1056, 0.1028],\n",
       "        [0.1085, 0.0969, 0.1000, 0.1024, 0.0997],\n",
       "        [0.0987, 0.0881, 0.0909, 0.0931, 0.0907]])"
      ]
     },
     "execution_count": 101,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# top-left 5x5 corner of the first correlation matrix\n",
    "corner = slice(0, 5)\n",
    "correlations[0][corner, corner]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# the integers 1 through 5 as a one-dimensional numpy array\n",
    "a = np.arange(1, 6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([3, 4])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# select the elements at positions 2 and 3 of the numpy array\n",
    "a.take([2, 3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "a = [1,2,3,4,5]\n",
    "# A plain list only accepts integer or slice indices, so a[[2,3]] raised TypeError;\n",
    "# pick the elements one index at a time instead.\n",
    "selected = [a[i] for i in (2, 3)]\n",
    "selected"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
